{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "1105.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/1105.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2IB77QOvsP0J",
        "colab_type": "text"
      },
      "source": [
        "## 4.1 模型构造"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "IqSzsL_qsuLO",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Install the notebook's dependencies with the %pip magic so they land in the\n",
        "# environment of the running kernel rather than whatever `pip` is first on the shell PATH.\n",
        "%pip install mxnet d2lzh"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wuq1Gdbm6SQd",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from mxnet import nd\n",
        "from mxnet.gluon import nn\n",
        "\n",
        "class MLP(nn.Block):\n",
        "    \"\"\"Perceptron with one 256-unit ReLU hidden layer and a 10-unit output layer.\"\"\"\n",
        "\n",
        "    def __init__(self, **kwargs):\n",
        "        # Keyword arguments are passed straight through to nn.Block.\n",
        "        super().__init__(**kwargs)\n",
        "        self.hidden = nn.Dense(256, activation='relu')  # hidden layer\n",
        "        self.output = nn.Dense(10)  # output layer\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Apply the hidden layer to ``x``, then the output layer, and return the result.\"\"\"\n",
        "        hidden_out = self.hidden(x)\n",
        "        return self.output(hidden_out)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jFvsEniqGFbK",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 122
        },
        "outputId": "ae4f8151-a138-4cef-bc70-ad8ba3ef8de0"
      },
      "source": [
        "# Random input of shape (2, 20).\n",
        "X = nd.random.uniform(shape=(2, 20))\n",
        "net = MLP()\n",
        "net.initialize()\n",
        "# Forward pass; the saved output is a 2x10 NDArray.\n",
        "net(X)"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[[-0.0822941   0.07456148  0.06385002 -0.04937973  0.0698424   0.02713358\n",
              "  -0.07743802 -0.00322946  0.02611737 -0.04490523]\n",
              " [-0.06431118  0.07135497  0.09599175 -0.06451458  0.09757216  0.04676059\n",
              "  -0.04679815  0.02725542  0.04404476 -0.04410914]]\n",
              "<NDArray 2x10 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QaoABMjm3YI5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class MySequential(nn.Block):\n",
        "    \"\"\"Minimal sequential container that applies its child blocks one after another.\"\"\"\n",
        "\n",
        "    def __init__(self, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "\n",
        "    def add(self, block):\n",
        "        \"\"\"Store ``block`` in ``self._children`` under the key ``block.name``.\"\"\"\n",
        "        child_key = block.name\n",
        "        self._children[child_key] = block\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Feed ``x`` through every stored child in iteration order of ``self._children``.\"\"\"\n",
        "        out = x\n",
        "        for child in self._children.values():\n",
        "            out = child(out)\n",
        "        return out"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z4uSfz-M5uYi",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 122
        },
        "outputId": "cce43ea2-3d52-4f46-f375-c522f537a6d6"
      },
      "source": [
        "# Assemble the same two layers as MLP (Dense 256 + ReLU, Dense 10) through MySequential.add.\n",
        "net = MySequential()\n",
        "net.add(nn.Dense(256, activation='relu'))\n",
        "net.add(nn.Dense(10))\n",
        "net.initialize()\n",
        "# Forward pass on the X defined earlier; the saved output is a 2x10 NDArray.\n",
        "net(X)"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[[-0.01043473  0.01791799  0.02722769  0.01189875 -0.03495108  0.07704126\n",
              "   0.00182015 -0.00797976 -0.043986   -0.02848949]\n",
              " [-0.02556653  0.01028888  0.02033464  0.01762148 -0.0338972   0.03672597\n",
              "   0.03781967 -0.03908284 -0.05747136  0.01365921]]\n",
              "<NDArray 2x10 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6JLUJHOP6NyE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class FancyMLP(nn.Block):\n",
        "    \"\"\"Block combining one Dense layer (applied twice), a constant matrix and Python control flow.\"\"\"\n",
        "\n",
        "    def __init__(self, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "        # 20x20 uniform random matrix registered through params.get_constant.\n",
        "        const_value = nd.random.uniform(shape=(20, 20))\n",
        "        self.rand_weight = self.params.get_constant('rand_weight', const_value)\n",
        "        self.dense = nn.Dense(20, activation='relu')\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Transform ``x`` through the steps below and return the sum of the result.\"\"\"\n",
        "        x = self.dense(x)\n",
        "        # Multiply by the constant matrix, add 1, then apply ReLU.\n",
        "        projected = nd.dot(x, self.rand_weight.data()) + 1\n",
        "        x = nd.relu(projected)\n",
        "        # The same Dense layer is applied a second time.\n",
        "        x = self.dense(x)\n",
        "        # Halve x in place until its norm is no larger than 1.\n",
        "        while x.norm().asscalar() > 1:\n",
        "            x /= 2\n",
        "        # When the norm is below 0.8, add 10 to every element.\n",
        "        if x.norm().asscalar() < 0.8:\n",
        "            x += 10\n",
        "        return x.sum()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "QSYldX7TniGp",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "outputId": "f3f2b781-cbdf-4d15-f450-6a5c13841180"
      },
      "source": [
        "net = FancyMLP()\n",
        "net.initialize()\n",
        "# Forward pass on X; the saved output is a single-element NDArray.\n",
        "net(X)"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[403.08377]\n",
              "<NDArray 1 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "prtocDacnoCQ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class NestMLP(nn.Block):\n",
        "    \"\"\"Block that wraps an inner nn.Sequential (Dense 64, Dense 32) followed by a Dense 16 layer.\"\"\"\n",
        "\n",
        "    def __init__(self, **kwargs):\n",
        "        super().__init__(**kwargs)\n",
        "        # Inner stack: two ReLU Dense layers, added in order.\n",
        "        self.net = nn.Sequential()\n",
        "        self.net.add(nn.Dense(64, activation='relu'))\n",
        "        self.net.add(nn.Dense(32, activation='relu'))\n",
        "        self.dense = nn.Dense(16, activation='relu')\n",
        "\n",
        "    def forward(self, x):\n",
        "        \"\"\"Run ``x`` through the inner stack, then through the final Dense layer.\"\"\"\n",
        "        inner_out = self.net(x)\n",
        "        return self.dense(inner_out)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gF66XV0woiku",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "outputId": "b0020777-6ab4-44e0-d510-f78134cd0657"
      },
      "source": [
        "# Chain a NestMLP, a plain Dense(20) layer and a FancyMLP inside one nn.Sequential.\n",
        "net = nn.Sequential()\n",
        "net.add(NestMLP(), nn.Dense(20), FancyMLP())\n",
        "\n",
        "net.initialize()\n",
        "# Forward pass on X; the saved output is a single-element NDArray.\n",
        "net(X)"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[3.7801871]\n",
              "<NDArray 1 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 24
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RhZtEIKXowyp",
        "colab_type": "text"
      },
      "source": [
        "## 4.2 模型参数的访问、初始化和共享"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EV5CC_DQpSRg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from mxnet import init, nd \n",
        "from mxnet.gluon import nn \n",
        "\n",
        "# Two-layer network used by the following cells of this section.\n",
        "net = nn.Sequential()\n",
        "net.add(nn.Dense(256, activation='relu'))\n",
        "net.add(nn.Dense(10))\n",
        "net.initialize()\n",
        "\n",
        "# Run one forward pass on a fresh (2, 20) random input.\n",
        "X = nd.random.uniform(shape=(2, 20))\n",
        "Y = net(X)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ogGGCpK_puqs",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 87
        },
        "outputId": "90699b30-78cf-41b0-9537-2555ca8736aa"
      },
      "source": [
        "# Show the parameters held by the first layer (a weight and a bias in the saved output).\n",
        "net[0].params"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "dense19_ (\n",
              "  Parameter dense19_weight (shape=(256, 20), dtype=float32)\n",
              "  Parameter dense19_bias (shape=(256,), dtype=float32)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "y0fb8sjkp2Vt",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "60d33ee9-3150-4a7a-81dd-f6c9d707e9f3"
      },
      "source": [
        "# Inspect the container type returned by .params (ParameterDict in the saved output).\n",
        "type(net[0].params)"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "mxnet.gluon.parameter.ParameterDict"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VLuerqlhp92y",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "b33ae77b-9684-43e1-8248-29255a515c48"
      },
      "source": [
        "# Look the weight up by its key in the ParameterDict. The key is built from the\n",
        "# layer's own prefix because Gluon numbers layers with a process-wide counter, so a\n",
        "# hard-coded name such as 'dense19_weight' only exists for one particular kernel\n",
        "# history and raises KeyError after Restart & Run All.\n",
        "net[0].params[net[0].prefix + 'weight']"
      ],
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Parameter dense19_weight (shape=(256, 20), dtype=float32)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 28
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Tj5aqiqRqNt_",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "32903b91-a034-496f-b55e-3a415e519638"
      },
      "source": [
        "# The weight Parameter is also reachable as an attribute of the layer.\n",
        "net[0].weight"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Parameter dense19_weight (shape=(256, 20), dtype=float32)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 29
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "I-IZD_vyqQ3M",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 281
        },
        "outputId": "b2fc9e3b-92e8-4010-9188-2bcd58364c01"
      },
      "source": [
        "# Values of the first layer's weight (a 256x20 NDArray in the saved output).\n",
        "net[0].weight.data()"
      ],
      "execution_count": 30,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[[-0.06763797  0.02341988  0.02750754 ... -0.0163656   0.0271539\n",
              "   0.04293371]\n",
              " [-0.057276   -0.00545905 -0.03811367 ...  0.017037    0.00038091\n",
              "  -0.01689015]\n",
              " [ 0.05912871  0.00589813  0.00579332 ... -0.03548513  0.06329998\n",
              "   0.00878964]\n",
              " ...\n",
              " [-0.04247909 -0.05434325  0.04364657 ... -0.00520646  0.04954348\n",
              "   0.0615086 ]\n",
              " [ 0.01628537 -0.00393314  0.06814129 ...  0.06664989 -0.06824005\n",
              "   0.02902017]\n",
              " [ 0.06353932 -0.03678216  0.03699008 ...  0.03383283 -0.04630034\n",
              "   0.06421676]]\n",
              "<NDArray 256x20 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 30
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hzzhzHP-1RzC",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 175
        },
        "outputId": "bd95fc17-2b55-40b7-e131-03330cc56109"
      },
      "source": [
        "# Gradient of the first layer's weight; the saved output is all zeros.\n",
        "net[0].weight.grad()"
      ],
      "execution_count": 31,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[[0. 0. 0. ... 0. 0. 0.]\n",
              " [0. 0. 0. ... 0. 0. 0.]\n",
              " [0. 0. 0. ... 0. 0. 0.]\n",
              " ...\n",
              " [0. 0. 0. ... 0. 0. 0.]\n",
              " [0. 0. 0. ... 0. 0. 0.]\n",
              " [0. 0. 0. ... 0. 0. 0.]]\n",
              "<NDArray 256x20 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 31
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "DYn7Q9Nh1Zhs",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "outputId": "bae76d2f-493d-42bb-9e04-acd7e1367305"
      },
      "source": [
        "# Bias values of the second layer (ten zeros in the saved output).\n",
        "net[1].bias.data()"
      ],
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]\n",
              "<NDArray 10 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 32
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LYCdfEch1e-x",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 122
        },
        "outputId": "5178ef19-5477-4802-a622-ff67d0548125"
      },
      "source": [
        "# List every parameter of the network (weights and biases of both layers in the saved output).\n",
        "net.collect_params()"
      ],
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "sequential4_ (\n",
              "  Parameter dense19_weight (shape=(256, 20), dtype=float32)\n",
              "  Parameter dense19_bias (shape=(256,), dtype=float32)\n",
              "  Parameter dense20_weight (shape=(10, 256), dtype=float32)\n",
              "  Parameter dense20_bias (shape=(10,), dtype=float32)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 33
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dzM3wmO_1iQZ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 87
        },
        "outputId": "2f02bb34-1a38-4b34-c84f-a04e1937bafd"
      },
      "source": [
        "# Filter parameters by a name pattern; only the two weight parameters appear in the saved output.\n",
        "net.collect_params('.*weight')"
      ],
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "sequential4_ (\n",
              "  Parameter dense19_weight (shape=(256, 20), dtype=float32)\n",
              "  Parameter dense20_weight (shape=(10, 256), dtype=float32)\n",
              ")"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 34
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6Z5jS8Nk1xHi",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 122
        },
        "outputId": "15fb6e77-e363-4d29-8f55-e5cbc33822d8"
      },
      "source": [
        "# Re-initialize the already-initialized network with Normal(sigma=0.01).\n",
        "net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)\n",
        "# Show the first row of the first layer's weight.\n",
        "net[0].weight.data()[0]"
      ],
      "execution_count": 35,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[ 0.00632535  0.02281385 -0.00790374  0.00079064 -0.00359422  0.00275971\n",
              "  0.01268833  0.01220814 -0.01035277 -0.00453738 -0.00711486  0.00247771\n",
              "  0.01414129 -0.00594405  0.00053162 -0.00421787 -0.01071319 -0.00609468\n",
              "  0.00426226 -0.00555236]\n",
              "<NDArray 20 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 35
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aLrlnL823S4U",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "outputId": "cc2f526e-b578-4af8-f116-0f6e5b672610"
      },
      "source": [
        "# Re-initialize with the constant 1; the saved output row is all ones.\n",
        "net.initialize(init.Constant(1), force_reinit=True)\n",
        "net[0].weight.data()[0]"
      ],
      "execution_count": 36,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]\n",
              "<NDArray 20 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 36
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "dAu0xK-k3l3h",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 122
        },
        "outputId": "d53b3fcf-3294-4273-b3b5-fad8fb854b24"
      },
      "source": [
        "# Re-initialize a single parameter (the first layer's weight) with Xavier.\n",
        "net[0].weight.initialize(init=init.Xavier(), force_reinit=True)\n",
        "net[0].weight.data()[0]"
      ],
      "execution_count": 37,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[-0.08361219  0.08051363 -0.09847713  0.07271375  0.12466413  0.06529617\n",
              " -0.06072348 -0.00504498 -0.01383175 -0.09022775 -0.00178175 -0.08682375\n",
              "  0.08202833  0.01045902  0.10150935 -0.01833758 -0.10643165  0.14006686\n",
              " -0.02155472 -0.05873382]\n",
              "<NDArray 20 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 37
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ax1CtNiI4HpV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class MyInit(init.Initializer):\n",
        "    \"\"\"Initializer drawing uniform values between -10 and 10 and zeroing those with magnitude below 5.\"\"\"\n",
        "\n",
        "    def _init_weight(self, name, data):\n",
        "        \"\"\"Print the parameter name and shape, then overwrite ``data`` in place.\"\"\"\n",
        "        print('Init', name, data.shape)\n",
        "        data[:] = nd.random.uniform(low=-10, high=10, shape=data.shape)\n",
        "        # Mask is 1 where |value| >= 5 and 0 elsewhere; multiplying zeroes the small entries.\n",
        "        keep_mask = data.abs() >= 5\n",
        "        data *= keep_mask"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YTq2i1Jj4qMq",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 158
        },
        "outputId": "d94ed2c8-23a3-4d39-c22a-a7ccea7f2028"
      },
      "source": [
        "# Re-initialize with the custom initializer; the saved output shows one 'Init' line per weight.\n",
        "net.initialize(MyInit(), force_reinit=True)\n",
        "net[0].weight.data()[0]"
      ],
      "execution_count": 39,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Init dense19_weight (256, 20)\n",
            "Init dense20_weight (10, 256)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[ 6.1872997  0.         0.         0.         9.086678  -0.\n",
              " -0.         0.         7.9508553  0.         5.3993435 -0.\n",
              " -0.         0.         0.        -0.        -0.        -8.298664\n",
              "  7.487999  -5.336525 ]\n",
              "<NDArray 20 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 39
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Pz_ZcxGQ41fS",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 122
        },
        "outputId": "f08f2840-517d-4bac-9609-5b18907c1ab4"
      },
      "source": [
        "# Overwrite the weight directly: every element is replaced by its current value plus 1.\n",
        "net[0].weight.set_data(net[0].weight.data() + 1)\n",
        "net[0].weight.data()[0]"
      ],
      "execution_count": 40,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[ 7.1872997  1.         1.         1.        10.086678   1.\n",
              "  1.         1.         8.950855   1.         6.3993435  1.\n",
              "  1.         1.         1.         1.         1.        -7.298664\n",
              "  8.487999  -4.336525 ]\n",
              "<NDArray 20 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 40
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lG4F-uqe5PWc",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 70
        },
        "outputId": "954365e4-5573-4e20-d689-3e77b4b43343"
      },
      "source": [
        "net = nn.Sequential()\n",
        "shared = nn.Dense(8, activation='relu')\n",
        "# The third layer is built with params=shared.params, i.e. it is handed the parameters of `shared`.\n",
        "net.add(nn.Dense(8, activation='relu'), \n",
        "    shared, \n",
        "    nn.Dense(8, activation='relu', params=shared.params), \n",
        "    nn.Dense(10))\n",
        "net.initialize()\n",
        "\n",
        "X = nd.random.uniform(shape=(2, 10))\n",
        "net(X)\n",
        "\n",
        "# Element-wise comparison of the first weight row of layers 1 and 2; the saved output is all ones.\n",
        "net[1].weight.data()[0] == net[2].weight.data()[0]"
      ],
      "execution_count": 42,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "\n",
              "[1. 1. 1. 1. 1. 1. 1. 1.]\n",
              "<NDArray 8 @cpu(0)>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 42
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fTKkEqpp6b97",
        "colab_type": "text"
      },
      "source": [
        "## 4.3 模型参数的延后初始化"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_1P-upsl61ik",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from mxnet import init, nd \n",
        "from mxnet.gluon import nn \n",
        "\n",
        "class MyInit(init.Initializer):\n",
        "    \"\"\"Initializer that only reports which weight it was asked to initialize.\n",
        "\n",
        "    Redefines the ``MyInit`` class declared earlier in this notebook.\n",
        "    \"\"\"\n",
        "\n",
        "    def _init_weight(self, name, data):\n",
        "        # Nothing is written into data here; only the name and shape are printed.\n",
        "        weight_shape = data.shape\n",
        "        print('Init', name, weight_shape)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HHCcmbSwOMLw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Neither Dense layer is given in_units here.\n",
        "net = nn.Sequential()\n",
        "net.add(nn.Dense(256, activation='relu'), \n",
        "    nn.Dense(10))\n",
        "\n",
        "# This cell has no saved output: no 'Init' message is printed by initialize() at this point.\n",
        "net.initialize(init=MyInit())\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gDDetm4cOqt1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "5deaace2-f5a0-4947-b068-e66260a92742"
      },
      "source": [
        "# First forward pass with a (2, 20) input; the saved output shows the 'Init' messages appearing here.\n",
        "X = nd.random.uniform(shape=(2, 20))\n",
        "Y = net(X)"
      ],
      "execution_count": 46,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Init dense27_weight (256, 20)\n",
            "Init dense28_weight (10, 256)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BYIH5jggO2Bk",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Second forward pass on the same input; this cell has no saved output (no 'Init' messages).\n",
        "Y = net(X)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "H-01Vic8O8_R",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "4cbae25e-436f-410e-d472-9573fd811b13"
      },
      "source": [
        "# Forced re-initialization after the forward passes; the 'Init' messages are printed by this call.\n",
        "net.initialize(init=MyInit(), force_reinit=True)"
      ],
      "execution_count": 48,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Init dense27_weight (256, 20)\n",
            "Init dense28_weight (10, 256)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0pVV9jJtPWkZ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 52
        },
        "outputId": "3b93d59d-4c6e-4af4-ea76-9cfba4809764"
      },
      "source": [
        "# Both layers declare in_units, so their weight shapes are fully specified at construction.\n",
        "net = nn.Sequential()\n",
        "net.add(nn.Dense(256, in_units=20, activation='relu'))\n",
        "net.add(nn.Dense(10, in_units=256))\n",
        "\n",
        "# The saved output shows the 'Init' messages printed by initialize() itself, before any forward pass.\n",
        "net.initialize(init=MyInit())"
      ],
      "execution_count": 49,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Init dense29_weight (256, 20)\n",
            "Init dense30_weight (10, 256)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0MvkPjQ5PxsP",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}